Skip to content

Conversation

@ClarkChin08
Copy link

No description provided.

Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds support for causal masking in the flash attention implementation by introducing a new SubgroupLayoutQK template parameter and implementing the causal mask logic in the mainloop.

Key Changes:

  • Added SubgroupLayoutQK template parameter to the collective mainloop and kernel interfaces
  • Implemented causal masking logic that applies -INFINITY to attention scores beyond the causal boundary
  • Updated the example runner to conditionally instantiate causal or non-causal configurations based on user options

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

File Description
applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp Implements causal mask logic and removes the static assertion that previously blocked causal mask usage
applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp Adds subgroup layout type alias and computes sequence coordinates for causal masking
examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp Adds SubgroupLayoutQK template parameter to mainloop type
examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp Conditionally selects causal or non-causal kernel based on is_causal option

@ClarkChin08 ClarkChin08 force-pushed the fa_causal_mask branch 2 times, most recently from bb07ccc to 836f2c4 Compare November 10, 2025 08:14
Signed-off-by: Chen, Xi2 <[email protected]>
}
}
}
}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ClarkChin08 Thanks for updating this! By the way, we can make the code even cleaner by including the block offset in cS_thread itself. Something like this should do it:

Tensor gP = local_tile(cP, TileShapeQK{}, blk_qv);
auto cS_thread = thr_mma_qk.partition_C(gP);

Then you don't need to do the blocking calculations here; instead row_idx = get<0>(cS_thread(i)), col_idx = get<1>(cS_thread(i)).

Copy link
Author

@ClarkChin08 ClarkChin08 Nov 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @petercad , I changed to use local_tile to get global col and row indices.


CUTE_HOST_DEVICE constexpr auto
get_atom_layout_mnk() const {
return atom_layout_mnk_;
Copy link

@petercad petercad Nov 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of adding a new atom_layout_mnk_ member:

Suggested change
return atom_layout_mnk_;
return AtomLayoutMNK{};

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

auto discard_seq_coord = s.seq_len_qo - offset;
auto full_tile_offset = s.seq_len_kv - offset;

int seq_coord = cute::min(s.seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + (sub_group_id / get<1>(shape(SubgroupLayoutQK{}))) * SGTileQ));
Copy link

@petercad petercad Nov 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sub_group_id / get<1>(shape(SubgroupLayoutQK{})) part is making a strong assumption about how subgroup tiles are arranged within the workgroup tile (K-major). We need to either add a static_assert for this condition, or (better) use CuTe layout algebra to calculate the subgroup Q offset. For instance:

auto cS = make_identity_tensor(take<0,2>(TiledMMAKQ{}.tile_size()));
auto tScS = TiledMMAKQ{}.get_slice(thread_idx).partition_C(cS);
auto q_offset_wi = get<0>(tScS(0));     /* Q offset for thread */
auto q_offset_sg = group_broadcast(sycl::ext::oneapi::this_work_item::get_sub_group(), q_offset_wi, 0);    /* Q offset for SG */

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, @petercad. I'm now implementing the algebraic approach you suggested for calculating q_offset_sg.

Signed-off-by: Chen, Xi2 <[email protected]>
Signed-off-by: Chen, Xi2 <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants